//   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <functional>
#include <iomanip>
#include <mutex>
#include <numeric>
#include <sstream>

#include <opencv2/dnn.hpp>

#include "common/transforms.h"
#include "common/model_config.h"
#include "det_predictor.h"

namespace ai {


/// Builds a detection predictor that reports through the given log settings.
/// @param lg  logging configuration consulted by every Run*/output_box call.
///            The pointer is stored as-is; NOTE(review): presumably the
///            caller keeps it alive for the predictor's lifetime — confirm.
PDDetPredictor::PDDetPredictor(LogInfo *lg)
  : log_ifo(lg)
{
}

/// Single-image entry point: forwards to RunDet() when the model at
/// `md_idx` is configured as a detector (algorithm == model_det).
/// @return RunDet()'s status for detection models,
///         model_not_implement_error for any other algorithm.
int PDDetPredictor::Run(cv::Mat& img,
                        std::vector<BaseInfo*>& det_infos,
                        PDModelManager* mng,
                        int md_idx)
{
  const bool is_detector = (mng->configs[md_idx]->algorithm == model_det);
  if (!is_detector) {
    return model_not_implement_error;
  }
  return RunDet(img, det_infos, mng, md_idx);
}
/// Batch entry point: forwards to the batch RunDet() overload when the model
/// at `md_idx` is a detector, otherwise reports model_not_implement_error.
int PDDetPredictor::Run(std::vector<cv::Mat> &imgs,
                        std::vector<std::vector<DetInfo>>& det_infos,
                        PDModelManager* mng,
                        int md_idx)
{
  return (mng->configs[md_idx]->algorithm == model_det)
             ? RunDet(imgs, det_infos, mng, md_idx)
             : model_not_implement_error;
}

/// Runs single-image detection with the model at `md_idx`.
///
/// Steps: preprocess `img` with cfg->transforms, feed the predictor inputs
/// that are recognised by name ("image"/"x0", "im_shape", "scale_factor"),
/// run inference, copy the first output tensor (and, for arch "YOLO_TORCH",
/// the second one) to host memory, then decode via output_box().
///
/// @param img        input image; an empty image is rejected.
/// @param det_infos  detections are appended as heap-allocated DetInfo
///                   objects by output_box(); existing entries are kept.
/// @param mng        model manager holding configs and predictors.
/// @param md_idx     index of the model inside `mng`.
/// @return 0 on success; IMAGE_EMPTY_ERROR for an empty image;
///         model_image_channels_check_error when preprocessing fails;
///         model_not_detect_any_object when the first output tensor holds
///         fewer than 6 values.
int PDDetPredictor::RunDet(cv::Mat& img,
                          std::vector<BaseInfo*>& det_infos,
                          PDModelManager* mng,
                          int md_idx)
{
  if (log_ifo->log_level_5) spdlog::get("logger")->info("1.2.0.0 rec_img");
  if (img.empty()) { return IMAGE_EMPTY_ERROR; }

  // ---- Preprocessing (runs before rec_mut is taken) ----
  double t1 = 0;
  if (log_ifo->log_level_4) { t1 = cv::getTickCount(); }
  ModelConfig *cfg = (ModelConfig*)mng->configs[md_idx];
  ImageBlob img_blob = ImageBlob();
  if (!cfg->transforms->run(img, img_blob, cfg)) { return model_image_channels_check_error; }
  if (log_ifo->log_level_5) spdlog::get("logger")->info("1.2.0.1 rec_img");
  if (log_ifo->log_level_4) {
    t1 = ((double)cv::getTickCount() - t1) / cv::getTickFrequency();
    spdlog::get("logger")->info("1.2.0.1 transforms Run Time: {}", t1);
  }

  // rec_mut guards the predictor's tensors from feeding the inputs until the
  // outputs are copied out. A scoped lock releases it on every exit path:
  // the early "nothing detected" return below and any exception thrown by
  // the predictor. The former manual lock()/unlock() pair left the mutex
  // held on that early return, deadlocking the next call.
  std::unique_lock<decltype(rec_mut)> rec_lock(rec_mut);

  // ---- Feed inputs ----
  auto input_names = mng->predictors[md_idx]->GetInputNames();
  if (log_ifo->log_level_5) { log_tensor_shape("1.2.0.1 input_names", input_names); }

  for (const auto& tensor_name : input_names) {
    auto in_tensor = mng->predictors[md_idx]->GetInputTensor(tensor_name);
    if (log_ifo->log_level_4) { spdlog::get("logger")->info("1.2.0.1 cur tensor_name: {}", tensor_name); }

    if (tensor_name == "image" || tensor_name == "x0") {
      // Layout follows cfg->data_format: NCHW when "CHW", NHWC otherwise.
      std::vector<int> shape;
      if (cfg->data_format == "CHW") { shape = { 1, cfg->channels, img_blob.new_im_shape[0], img_blob.new_im_shape[1] }; }
      else { shape = { 1, img_blob.new_im_shape[0], img_blob.new_im_shape[1], cfg->channels }; }
      if (log_ifo->log_level_4) { log_tensor_shape("1.2.0.1 image shape: ", shape); }
      in_tensor->Reshape(shape);
      if (log_ifo->log_level_4) { log_tensor_shape("1.2.0.1 Tensor image", in_tensor->shape()); }
      if (log_ifo->log_level_4) { spdlog::get("logger")->info("1.2.0.1 img_blob->im_data: {}", img_blob.im_data.size()); }

      in_tensor->CopyFromCpu(img_blob.im_data.data());
    }
    else if (tensor_name == "im_shape") {
      std::vector<int> shape = {1, 2};
      in_tensor->Reshape(shape);
      std::vector<float> os = {float(img_blob.ori_im_shape[0]), float(img_blob.ori_im_shape[1])};
      if (log_ifo->log_level_4) { log_tensor_shape("1.2.0.1 im_shape: ", os); }

      if (log_ifo->log_level_4) { log_tensor_shape("1.2.0.1 Tensor im_shape", in_tensor->shape()); }
      in_tensor->CopyFromCpu(os.data());
    }
    else if (tensor_name == "scale_factor") {
      // The same scale is supplied for both axes.
      std::vector<int> shape = {1, 2};
      in_tensor->Reshape(shape);
      std::vector<float> os = {img_blob.scale, img_blob.scale};
      if (log_ifo->log_level_4) { log_tensor_shape("1.2.0.1 scale_factor: ", os); }
      if (log_ifo->log_level_4) { log_tensor_shape("1.2.0.1 Tensor scale_factor", in_tensor->shape()); }
      in_tensor->CopyFromCpu(os.data());
    }
    if (log_ifo->log_level_4) { spdlog::get("logger")->info("1.2.0.1 cur tensor_name end"); }
  }
  if (log_ifo->log_level_5) spdlog::get("logger")->info("1.2.0.2 rec_img");

  // ---- Inference ----
  double t2 = 0;
  if (log_ifo->log_level_3) { t2 = cv::getTickCount(); }
  mng->predictors[md_idx]->ZeroCopyRun();
  if (log_ifo->log_level_3) {
    t2 = ((double)cv::getTickCount() - t2) / cv::getTickFrequency();
    spdlog::get("logger")->info("1.2.0.2 predictors Run Time: {}", t2);
  }
  if (log_ifo->log_level_5) spdlog::get("logger")->info("1.2.0.3 rec_img");

  // ---- Fetch outputs ----
  auto output_names = mng->predictors[md_idx]->GetOutputNames();
  if (log_ifo->log_level_5) { log_tensor_shape("1.2.0.3 output_names", output_names); }

  auto out_tensor = mng->predictors[md_idx]->GetOutputTensor(output_names[0]);
  std::vector<int> output_shape = out_tensor->shape();
  if (log_ifo->log_level_5) { log_tensor_shape("1.2.0.3 output_tensor", output_shape); }

  // "YOLO_TORCH" models expose a second output that output_box() reads as
  // 4 box coordinates per detection.
  std::vector<float> output_data_2;
  if (cfg->arch == "YOLO_TORCH") {
    auto out_tensor2 = mng->predictors[md_idx]->GetOutputTensor(output_names[1]);
    std::vector<int> output_shape2 = out_tensor2->shape();
    if (log_ifo->log_level_4) { log_tensor_shape("output_tensor2", output_shape2); }
    int size2 = std::accumulate(output_shape2.begin(), output_shape2.end(), 1, std::multiplies<int>());
    output_data_2.resize(size2);
    out_tensor2->CopyToCpu(output_data_2.data());
  }

  // Total element count of the first output tensor.
  int size = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies<int>());
  if (size < 6) { return model_not_detect_any_object; }  // rec_lock releases rec_mut here
  if (log_ifo->log_level_5) spdlog::get("logger")->info("1.2.0.4 rec_img");

  std::vector<float> output_data(size);
  out_tensor->CopyToCpu(output_data.data());
  // Outputs now live in host buffers; decode them without holding the lock.
  rec_lock.unlock();

  // ---- Post-processing ----
  output_box(output_data, img, img_blob, det_infos, cfg, output_data_2);

  // Debug aid: dump an annotated copy of the image at log level 4.
  if (log_ifo->log_level_4 && det_infos.size() > 0) {
    cv::Mat draw_box = visualize(img, det_infos, cfg->label_list);
    cv::imwrite(log_ifo->get_log_path() + "/"
                + std::to_string(log_ifo->counter())
                + "_draw_box.jpg", draw_box);
  }
  return 0;
}

/// Batch detection entry point — NOT IMPLEMENTED.
///
/// Currently a no-op: `det_infos` is left untouched and 0 (success) is
/// returned regardless of the inputs. A draft batched pipeline is kept below
/// as a comment for reference; it is not compilable as-is (for example it
/// declares std::vector<cv::Mat&>, which is ill-formed).
int PDDetPredictor::RunDet(std::vector<cv::Mat> &imgs,
                          std::vector<std::vector<DetInfo>>& det_infos,
                          PDModelManager* mng,
                          int md_idx)
{
  // Inputs are intentionally unused until the batch path is implemented.
  static_cast<void>(imgs);
  static_cast<void>(det_infos);
  static_cast<void>(mng);
  static_cast<void>(md_idx);

  // ---- Draft batched implementation (disabled) ----
  // int ret = 0;
  // ModelConfig* cfg = mng->configs[md_idx];

  // auto total_size = imgs.size();
  // int default_batch_size = std::min(cfg->batch_size,
  //                                     static_cast<int>(total_size));
  // int batch = total_size / default_batch_size
  //             + ((total_size % default_batch_size) != 0);

  // for (int u = 0; u < batch; ++u) {
  //   int batch_size = default_batch_size;
  //   if (u == (batch - 1) && (total_size % default_batch_size)) {
  //       batch_size = total_size % default_batch_size;
  //   }

  //   int real_buffer_size = batch_size * cfg->channels * \
  //             cfg->image_shape[0] * cfg->image_shape[1];
  //   std::vector<float> input_buffer(real_buffer_size, 0.);

  //   std::vector<ImageBlob*> data_inputs;
  //   std::vector<cv::Mat&> imgs_batch;
  //   for (int i = 0; i < batch_size; ++i) {
  //       int idx = u * default_batch_size + i;
  //       imgs_batch.push_back(&imgs[idx]);
  //       ImageBlob* ib = new ImageBlob(&imgs[idx]);
  //       ib->im_data = input_buffer.data() + i * cfg->channels \
  //               * cfg->image_shape[0] * cfg->image_shape[1];
  //       data_inputs.push_back(ib);
  //   }

  //   int64 start = cv::getTickCount();
  //   ret = det_pp_->batch_process(imgs_batch, data_inputs, cfg);

  //   // Prepare input tensor
  //   auto input_names = mng->predictors[md_idx]->GetInputNames();
  //   for (const auto& tensor_name : input_names) {
  //     auto in_tensor = mng->predictors[md_idx]->GetInputHandle(tensor_name);
  //     if (tensor_name == "image") {
  //       in_tensor->Reshape({batch_size, cfg->channels, 
  //                           cfg->image_shape[0], cfg->image_shape[1]});
  //       in_tensor->CopyFromCpu(input_buffer.data());
  //     } else if (tensor_name == "im_size") {
  //       in_tensor->Reshape({batch_size, 2});
  //       std::vector<int> batch_im_shape;
  //       for (auto di : data_inputs) { batch_im_shape.insert(batch_im_shape.begin(), 
  //         di->ori_im_shape.begin(), di->ori_im_shape.end());}
  //       in_tensor->CopyFromCpu(batch_im_shape.data());
  //     } else if (tensor_name == "im_shape") {
  //       in_tensor->Reshape({1, 3});
  //       std::vector<float> os;
  //       os.push_back(float(data_inputs[0]->ori_im_shape[0]));
  //       os.push_back(float(data_inputs[0]->ori_im_shape[1]));
  //       os.push_back(1.0);
  //       in_tensor->CopyFromCpu(os.data());
  //     } else if (tensor_name == "scale_factor") {
  //       in_tensor->Reshape({1, 2});
  //       in_tensor->CopyFromCpu(data_inputs[0]->scale_factor.data());
  //     }
  //   }

  //   mng->predictors[md_idx]->Run();
  //   // Get output tensor
  //   auto output_names = mng->predictors[md_idx]->GetOutputNames();
  //   auto out_tensor = mng->predictors[md_idx]->GetOutputHandle(output_names[0]);
  //   std::vector<int> output_shape = out_tensor->shape();
  //   if (log_ifo->log_level_4) {
  //     spdlog::get("logger")->info("output_names: ");
  //     for (auto n : output_names) { spdlog::get("logger")->info(n << " ");}
  //     spdlog::get("logger")->info(std::endl);
  //     spdlog::get("logger")->info("output_shape: ");
  //     for ( auto x : output_shape) { spdlog::get("logger")->info(x << " ");}
  //     spdlog::get("logger")->info(std::endl);
  //   }
  //   // Calculate output length
  //   int output_size = 1;
  //   for (int j = 0; j < output_shape.size(); ++j) {
  //     output_size *= output_shape[j];
  //   }

  //   if (output_size < 6) {
  //     std::cerr << "[WARNING] No object detected." << std::endl;
  //   }
  //   std::vector<float> output_data(output_size);
  //   out_tensor->CopyToCpu(output_data.data()); 

  //   if (log_ifo->log_level_4) {
  //     int64 end = cv::getTickCount();
  //     float ms = static_cast<float>(end - start) / cv::getTickFrequency()  * 1000.;
  //     printf("Inference: %f ms per batch image\n", ms);
  //   }
  //   // Postprocessing result
  //   output_box(output_data, imgs,  data_inputs, det_infos, cfg);

  // }
  return 0;
}

void PDDetPredictor::output_box(
    std::vector<float>& output_data,
    cv::Mat& img,
    ImageBlob& img_info,
    std::vector<BaseInfo*>& det_infos,
    ModelConfig* cfg,
    const std::vector<float>& bbox) 
{
  if (log_ifo->log_level_5) spdlog::get("logger")->info("1.2.0.4.0 rec_img");

  int rh = 1;
  int rw = 1;
  std::vector<cv::Rect> boxs;
  std::vector<float> scores;
  std::vector<int> labels;
  float nms_threshold;

  if (cfg->arch == "YOLO_TORCH") {
    nms_threshold = 0.6;
    float w_ratio = img_info.ori_im_shape[1] * 1.0 / img_info.new_im_shape[1];
    float h_ratio = img_info.ori_im_shape[0] * 1.0 / img_info.new_im_shape[0];
    // log_tensor_shape("ori_im_shape", img_info->ori_im_shape);
    // log_tensor_shape("new_im_shape", img_info->new_im_shape);

    int class_num = cfg->label_list.size();
    int score_size = output_data.size() / class_num;
    int box_size = bbox.size() / 4;
    if (score_size != box_size) { return;} 

    for (int i = 0; i < box_size; i++) {
      auto it = std::max_element(output_data.begin() + i * class_num, output_data.begin() + (i+1) * class_num);
      if (*it < cfg->draw_threshold) { continue; }
      int w = bbox[i * 4 + 2] * img_info.ori_im_shape[1];
      int h = bbox[i * 4 + 3] * img_info.ori_im_shape[0];
      int x = bbox[i * 4 + 0] * img_info.ori_im_shape[1] - w / 2;
      int y = bbox[i * 4 + 1] * img_info.ori_im_shape[0] - h / 2;

      if (x < 0) { w += x; x = 0;}
      if (x + w > img.cols) { w = img.cols - x;}                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                           
      if (y < 0) { h += y; y = 0;}
      if (y + h > img.rows) { h = img.rows - y;}
      boxs.push_back(cv::Rect(x,y,w,h));
      scores.push_back(*it);
      labels.push_back(std::distance(output_data.begin() + i * class_num, it));
    }
  }
  else if (cfg->arch == "YOLOV5") {
    nms_threshold = 0.6;
    // float w_ratio = img_info->ori_im_shape[1] * 1.0 / img_info->new_im_shape[1];
    // float h_ratio = img_info->ori_im_shape[0] * 1.0 / img_info->new_im_shape[0];

    int class_num = cfg->label_list.size();
    int step = class_num+5;
    int box_size = output_data.size() / step;
    if (log_ifo->log_level_5) spdlog::get("logger")->info("1.2.0.4.0 class_num:{}, step:{} box_size:{}", class_num, step, box_size);

    for (int i = 0; i < box_size; i++) {
      auto it = std::max_element(output_data.begin() + i*step+5, output_data.begin() + i*step+step);
      float score = *it * output_data[i*step + 4];

      if (score < cfg->draw_threshold) { continue; }
      // if (score > 0.8) {
      //   std::cout << output_data[i*step + 0] * w_ratio<< "  " << output_data[i*step + 1]*h_ratio << "  " << output_data[i*step + 2] * w_ratio << "  " << output_data[i*step + 3] *h_ratio<< "  " << output_data[i*step + 4] << "  " << output_data[i*step + 5] << "  " << output_data[i*step + 6] << "  " << output_data[i*step + 7] << "  " << output_data[i*step + 8] << "  " << output_data[i*step + 9] << "  " << output_data[i*step + 10] << "  " << output_data[i*step + 11]<< std::endl;
      // }
      // std::cout << "score: " << score << std::endl;

      float w = output_data[i*step + 2] / img_info.scale;
      float h = output_data[i*step + 3] / img_info.scale;
      float x = output_data[i*step + 0] / img_info.scale - w/2;
      float y = output_data[i*step + 1] / img_info.scale - h/2;
      // std::cout << "(" << x << "," << y << "," << w<<","<<h<<") "  << std::endl;

      if (x < 0) { w += x; x = 0;}
      if (x + w > img.cols) { w = img.cols - x;}                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                           
      if (y < 0) { h += y; y = 0;}
      if (y + h > img.rows) { h = img.rows - y;}
      boxs.push_back(cv::Rect(x,y,w,h));
      scores.push_back(score);
      labels.push_back(std::distance(output_data.begin() + i*step+5, it));
    }
  }
  else {
    nms_threshold = 0.3;
    int box_size = output_data.size() / 6;
    for (int j = 0; j < box_size; ++j) {
      // Class id
      int class_id = static_cast<int>(round(output_data[0 + j * 6]));
      // Confidence score
      float score = output_data[1 + j * 6];
      int xmin = (output_data[2 + j * 6] * rw);
      int ymin = (output_data[3 + j * 6] * rh);
      int xmax = (output_data[4 + j * 6] * rw);
      int ymax = (output_data[5 + j * 6] * rh);

      if (score > cfg->draw_threshold && class_id > -1) {
        if (xmin < 0) { xmin = 0;}
        if (ymin < 0) { ymin = 0;}
        if (xmax > img.cols) { xmax = img.cols;}
        if (ymax > img.rows) { ymax = img.rows;}
        int wd = xmax - xmin;
        int hd = ymax - ymin;

        boxs.push_back(cv::Rect(xmin, ymin, wd, hd));
        scores.push_back(score);
        labels.push_back(class_id);
      }
    }
  }
  if (log_ifo->log_level_5) spdlog::get("logger")->info("1.2.0.4.1 boxs:{}", boxs.size());

  std::vector<int> indices;
  cv::dnn::NMSBoxes(boxs, scores, cfg->draw_threshold, nms_threshold, indices, 100);
  for (auto& idx : indices) {
    DetInfo* di = new DetInfo(scores[idx], labels[idx], boxs[idx], cfg->label_list[labels[idx]]);
    det_infos.push_back(di);
  }
  if (log_ifo->log_level_5) spdlog::get("logger")->info("1.2.0.4.2 det_infos:{}", det_infos.size());

}

/// Batch counterpart of output_box() — NOT IMPLEMENTED.
///
/// Currently a no-op: `det_infos` is left untouched. A draft decoder for the
/// 6-floats-per-row layout (class_id, score, xmin, ymin, xmax, ymax) is kept
/// below as a comment. The draft expects local scale factors `rh` and `rw`,
/// both initialised to 1.
void PDDetPredictor::output_box(
    std::vector<float>& output_data,
    std::vector<cv::Mat>& imgs,
    std::vector<ImageBlob>& img_info,
    std::vector<std::vector<DetInfo>>& det_infos,
    ModelConfig* cfg)
{
  // Inputs are intentionally unused until the batch decoder is implemented.
  static_cast<void>(output_data);
  static_cast<void>(imgs);
  static_cast<void>(img_info);
  static_cast<void>(det_infos);
  static_cast<void>(cfg);

  // ---- Draft batch decoder (disabled) ----
  // int rh = 1;
  // int rw = 1;

  // int total_size = output_data.size() / 6;
  // if (total_size <= 0) {
  //   det_infos.resize(imgs.size());
  //   return ;
  // }
  // int batch_idx = -1;
  // float last_score = 0.;
  // std::vector<DetInfo> det_info;
  // for (int j = 0; j < total_size; ++j) {
  //   // Class id
  //   int class_id = static_cast<int>(round(output_data[0 + j * 6]));
  //   // Confidence score
  //   float score = output_data[1 + j * 6];
  //   if (score > last_score) {
  //     if (j != 0) { det_infos.push_back(det_info); }
  //     det_info.clear();
  //     batch_idx++;
  //     if (cfg->arch == "SSD" || cfg->arch == "Face") {
  //       rh = imgs[batch_idx].rows;
  //       rw = imgs[batch_idx].cols;
  //     }
  //   }
  //   last_score = score;

  //   int xmin = (output_data[2 + j * 6] * rw);
  //   int ymin = (output_data[3 + j * 6] * rh);
  //   int xmax = (output_data[4 + j * 6] * rw);
  //   int ymax = (output_data[5 + j * 6] * rh);
  //   if (log_ifo->log_level_4) {
  //   }

  //   if (score > cfg->draw_threshold && class_id > -1) {
  //     if (xmin < 0) { xmin = 0;}
  //     if (ymin < 0) { ymin = 0;}
  //     if (xmax > imgs[batch_idx].cols) { xmax = imgs[batch_idx].cols;}
  //     if (ymax > imgs[batch_idx].rows) { ymax = imgs[batch_idx].rows;}
  //     int wd = xmax - xmin;
  //     int hd = ymax - ymin;

  //     cv::Rect box = cv::Rect(xmin, ymin, wd, hd);

  //     det_info.push_back(DetInfo(class_id,score, box));
  //   }
  // }
  // det_infos.push_back(det_info); 
}

}  // namespace ai
